"""GANITE Codebase.

Reference: Jinsung Yoon, James Jordon, Mihaela van der Schaar, 
"GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets", 
International Conference on Learning Representations (ICLR), 2018.

Paper link: https://openreview.net/forum?id=ByKWUeWA-

Last updated Date: April 25th 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)

-----------------------------

data_loading.py

Note: Load real-world individualized treatment effects estimation datasets

(1) data_loading: Load the dataset selected by the config (e.g. twins data).
  - Reference: http://data.nber.org/data/linked-birth-infant-death-data-vital-statistics-data.html
"""

# Necessary packages
import numpy as np
import pandas as pd
from scipy.special import expit
import os

def data_loading(config):
  """Load a treatment-effect dataset from CSV and split it into train/test.

  The CSV is read from
  `<rootPath>/dataset/<dataset>/<dataset><start_order>.csv`.

  Args:
    - config: mapping with the following keys
      - 'rootPath': base directory that contains the 'dataset' folder
      - 'dataset': dataset name (used as folder name and file-name prefix)
      - 'start_order': suffix appended to the dataset name in the file name
      - 'splits': '/'-separated string of ratios, e.g. '0.63 / 0.27 / 0.1';
        only the first entry (the training ratio) is used here

  Returns:
    - train_x: features in training data
    - train_t: treatments in training data
    - train_y: observed outcomes in training data
    - train_potential_y: potential outcomes in training data
    - test_x: features in testing data
    - test_potential_y: potential outcomes in testing data
  """

  # Load the original data. The rows are shuffled once here; this draw is kept
  # so that the random stream consumed by this function is unchanged.
  feat_path = os.path.join(config['rootPath'], 'dataset/{}/{}{}.csv'.format(
    config['dataset'], config['dataset'], config['start_order']))
  feat = pd.read_csv(feat_path).sample(frac=1).reset_index(drop=True)

  # config['splits'] looks like '0.63 / 0.27 / 0.1'; the first ratio is the
  # fraction of rows used for training.
  train_rate = float(config['splits'].strip().split('/')[0])

  # Define features
  # NOTE(review): assumes the first five columns are non-feature columns
  # (treatment / outcome fields) - confirm against the CSV layout.
  x = feat.iloc[:, 5:].values
  t = feat['treatment'].values
  y = feat['yf'].values
  potential_y = feat[['mu0', 'mu1']].values

  no = x.shape[0]

  ## Train/test division
  idx = np.random.permutation(no)
  n_train = int(train_rate * no)
  train_idx = idx[:n_train]
  test_idx = idx[n_train:]

  train_x = x[train_idx, :]
  train_t = t[train_idx]
  train_y = y[train_idx]
  train_potential_y = potential_y[train_idx, :]

  test_x = x[test_idx, :]
  test_potential_y = potential_y[test_idx, :]

  return train_x, train_t, train_y, train_potential_y, test_x, test_potential_y

